/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.Arrays;
import java.util.List;

import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;

/**
 * @author petrov
 *
 */
public class HierarchicalFullyConnectedLexicon extends HierarchicalLexicon {

	private static final long serialVersionUID = 1L;
	protected int knownWordCount;


	/**
	 * @param numSubStates
	 * @param threshold
	 */
	public HierarchicalFullyConnectedLexicon(short[] numSubStates, int knownWordCount) {
		super(numSubStates, 0);
		this.knownWordCount = knownWordCount;
	}
	
	public HierarchicalFullyConnectedLexicon(short[] numSubStates, int smoothingCutoff, double[] smoothParam, 
			Smoother smoother, StateSetTreeList trainTrees, int knownWordCount) {
  	this(numSubStates, knownWordCount);
  	init(trainTrees);
  }

	
  /**
	 * @param previousLexicon
	 */
	public HierarchicalFullyConnectedLexicon(SimpleLexicon previousLexicon, int knownWordCount) {
		super(previousLexicon);
		this.knownWordCount = knownWordCount;
	}
	
	public HierarchicalFullyConnectedLexicon newInstance() {
		return new HierarchicalFullyConnectedLexicon(this.numSubStates,this.knownWordCount);
	}

  public void init(StateSetTreeList trainTrees){
  	for (Tree<StateSet> tree : trainTrees){
  		List<StateSet> words = tree.getYield();
  		for (StateSet word : words){
				String sig = word.getWord();
				wordIndexer.add(sig);
  		}
  	}
  	wordCounter = new int[wordIndexer.size()];
  	for (Tree<StateSet> tree : trainTrees){
  		List<StateSet> words = tree.getYield();
  		int ind = 0;  		
  		for (StateSet word : words){
  			String wordString = word.getWord();
  			wordCounter[wordIndexer.indexOf(wordString)]++;
  			
				String sig = getSignature(word.getWord(), ind++);
				wordIndexer.add(sig);
  		}  		
  	}
  	
  	tagWordIndexer = new IntegerIndexer[numStates];
  	for (int tag=0; tag<numStates; tag++){
  		tagWordIndexer[tag] = new IntegerIndexer(wordIndexer.size());
  	}

  	labelTrees(trainTrees);

  	boolean[] lexTag = new boolean[numStates]; 
  	for (Tree<StateSet> tree : trainTrees){
  		List<StateSet> words = tree.getYield();
  		List<StateSet> tags = tree.getPreTerminalYield();
  		int ind = 0;
  		for (StateSet word : words){
  			int tag = tags.get(ind).getState();
				tagWordIndexer[tag].add(new Integer(word.wordIndex));
				tagWordIndexer[tag].add(new Integer(word.sigIndex));
				lexTag[tag] = true;
				ind++;
  		}  		
  	}


  	expectedCounts = new double[numStates][][];
  	scores = new double[numStates][][];
  	for (int tag=0; tag<numStates; tag++){
  		if (!lexTag[tag]) {
  			tagWordIndexer[tag] = null;
  			continue;
  		}
//  		else tagWordIndexer[tag] = tagIndexer;
//  		expectedCounts[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()];
  		scores[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()];
  	}
  	nWords = wordIndexer.size();
  }

	public double[] score(int globalWordIndex, int globalSigIndex, short tag, int loc, boolean noSmoothing, boolean isSignature) {
		double[] res = new double[numSubStates[tag]];
		if (globalWordIndex!=-1) {
			int tagSpecificWordIndex = tagWordIndexer[tag].indexOf(globalWordIndex);
			if (tagSpecificWordIndex!=-1){
				for (int i=0; i<numSubStates[tag]; i++){
					res[i] = scores[tag][i][tagSpecificWordIndex]; 
				}
			} else {
				Arrays.fill(res, 1.0);
			}
		} else {
			Arrays.fill(res, 1.0);
		}
		if (globalWordIndex>=0 && (wordCounter[globalWordIndex]>knownWordCount)) {
//			if (globalSigIndex!=-1) System.out.println("Problem: frequent word has signature!");
			return res;
		}
		if (globalSigIndex!=-1) {
			int tagSpecificWordIndex = tagWordIndexer[tag].indexOf(globalSigIndex);
			if (tagSpecificWordIndex!=-1){
				for (int i=0; i<numSubStates[tag]; i++){
					res[i] *= scores[tag][i][tagSpecificWordIndex]; 
				}
//			} else{
//				System.out.println("unseen sig-tag pair");
			}
//		} else{
//			System.out.println("unseen sig");
		}
//		if (smoother!=null) smoother.smooth(tag,res);
		return res;
  } 


	public double[] score(StateSet stateSet, short tag, boolean noSmoothing, boolean isSignature) {
		if (stateSet.wordIndex == -2) {
			String word = stateSet.getWord();
			if (isSignature){
				stateSet.wordIndex = -1;
				stateSet.sigIndex = wordIndexer.indexOf(word);
			} else {
				stateSet.wordIndex = wordIndexer.indexOf(word);
//				if (stateSet.wordIndex > wordCounter.length){
//					System.out.println("no count for this word: "+(String)wordIndexer.get(tagWordIndexer[tag].get(stateSet.wordIndex)));
//					stateSet.sigIndex = -1;
//				} else {
				if ((stateSet.wordIndex>=0 && (wordCounter[stateSet.wordIndex]>knownWordCount)) || noSmoothing)
					stateSet.sigIndex = -1;
				else if (knownWordCount > 0)
					stateSet.sigIndex = wordIndexer.indexOf(getSignature(word,stateSet.from));
				else 
					stateSet.wordIndex = wordIndexer.indexOf(getSignature(word,stateSet.from));
			}
//			}
		}
		return score(stateSet.wordIndex, stateSet.sigIndex, tag, stateSet.from, noSmoothing, isSignature);
	}
	
	
	public void labelTrees(StateSetTreeList trainTrees){
  	for (Tree<StateSet> tree : trainTrees){
  		List<StateSet> words = tree.getYield();
  		List<StateSet> tags = tree.getPreTerminalYield();
  		int ind = 0;
  		for (StateSet word : words){
  			word.wordIndex = wordIndexer.indexOf(word.getWord());
  			if (word.wordIndex<0 || word.wordIndex>=wordCounter.length){
  				System.out.println("Have never seen this word before: "+word.getWord()+" "+word.wordIndex);
  				System.out.println(tree);
  			}
  			else if (wordCounter[word.wordIndex]<=knownWordCount){
	  			short tag = tags.get(ind).getState();
					String sig = getSignature(word.getWord(), ind);
					wordIndexer.add(sig);
	  			word.sigIndex = wordIndexer.indexOf(sig);
					tagWordIndexer[tag].add(wordIndexer.indexOf(sig));
  			}
  			else 
  				word.sigIndex = -1;
				ind++;
  		}  		
  	}
	}
	
	


}
